# -*- coding: utf-8 -*-
# pylint:disable=not-an-iterable, redefined-builtin

import json
import time
import uuid
from typing import Dict, List, Optional, Union

from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_stream_options_param import (
    ChatCompletionStreamOptionsParam,
)
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Annotated, Literal

from .agent_schemas import Role, Tool, FunctionCall


def generate_tool_call_id(prefix: str = "call_") -> str:
    """
    Build a random identifier for a tool call.

    :param prefix: string prepended to the random part
    :return: ``prefix`` followed by the FIRST 22 hex characters of a fresh
        UUID4 (lowercase, no dashes)
    """
    # uuid4().hex is the dash-free 32-character hex form of the UUID; keep
    # its first 22 characters as the random part.
    return f"{prefix}{uuid.uuid4().hex[:22]}"


# Image content part of a multimodal chat message.
class ImageMessageContent(BaseModel):
    class ImageUrl(BaseModel):
        """
        Location and requested detail level of an image content part.
        """

        # Either an image URL or the base64-encoded image data.
        url: str = Field(...)

        # Detail level requested for the image; "low" unless specified.
        detail: Literal["auto", "low", "high"] = Field(default="low")

    # Discriminator marking this content part as an image.
    type: Literal["image_url"] = Field(default="image_url")

    # Where the image lives and how detailed it should be processed.
    image_url: ImageUrl = Field(...)


# Plain-text content part of a multimodal chat message.
class TextMessageContent(BaseModel):
    # Discriminator marking this content part as text.
    type: Literal["text"] = Field(default="text")

    # The text carried by this part.
    text: str = Field(...)


# Audio-input content part of a multimodal chat message.
class AudioMessageContent(BaseModel):
    class InputAudioDetail(BaseModel):
        """
        Base64 payload and encoding format of an audio content part.
        """

        # Base64-encoded audio bytes; empty string by default.
        base64_data: str = Field(
            default="",
            description="the base64 data of multi-modal file",
        )

        # Encoding of the payload, e.g. "wav" or "mp3" (default "mp3").
        format: str = Field(
            default="mp3",
            description="The format of the encoded audio data.  supports "
            "'wav' and 'mp3'.",
        )

        @property
        def data(self) -> str:
            """Return the payload as ``data:<format>;base64,<base64_data>``."""
            # NOTE(review): the media-type slot holds the bare format (e.g.
            # "mp3"), not a MIME type such as "audio/mpeg" -- confirm that
            # consumers of this string expect that.
            return "data:{};base64,{}".format(self.format, self.base64_data)

    # Discriminator marking this content part as audio input.
    type: Literal["input_audio"] = Field(default="input_audio")

    # The encoded audio and its format.
    input_audio: InputAudioDetail = Field(...)


# Content part accepted in a message's list content. Pydantic uses the
# ``type`` field as the discriminator to pick the concrete model for each part.
ChatCompletionMessage = Annotated[
    Union[
        TextMessageContent,
        ImageMessageContent,
        AudioMessageContent,
    ],
    Field(discriminator="type"),
]


class ToolCall(BaseModel):
    """
    A single tool invocation requested by the assistant.
    """

    # Position of this call within the message's ``tool_calls`` list.
    index: int = Field(default=0)

    # Identifier of the tool call.
    id: str = Field(...)

    # Tool kind; only ``function`` is currently supported. Optional here.
    type: Optional[str] = Field(default=None)

    # The function the model chose to call (see ``FunctionCall``).
    function: FunctionCall = Field(...)


class OpenAIMessage(BaseModel):
    """
    Model class for prompt message.
    """

    role: str
    """The role of the message's author, e.g. `user`, `system`, `assistant`
    or `tool`."""

    content: Optional[Union[str, List[ChatCompletionMessage]]] = None
    """The contents of the message.

    Either a plain string or a list of content parts (text, image, audio)
    for multimodal messages.
    """

    name: Optional[str] = None
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_calls: Optional[List[ToolCall]] = None
    """The tool calls generated by the model, such as function calls."""

    def _content_parts(self, part_type: str) -> list:
        """
        Return the list-content parts whose ``type`` equals ``part_type``.

        :param part_type: content-part discriminator, e.g. ``"text"``
        :return: matching parts in their original order; empty when
            ``content`` is ``None`` or a plain string
        """
        if not isinstance(self.content, list):
            return []
        # Items without a ``type`` attribute are skipped rather than raising.
        return [
            item
            for item in self.content
            if getattr(item, "type", None) == part_type
        ]

    def get_text_content(self) -> Optional[str]:
        """
        Extract the first text content from the message.

        :return: ``content`` itself when it is a string, otherwise the text
            of the first text part, or ``None`` if there is no text
        """
        if isinstance(self.content, str):
            return self.content
        for part in self._content_parts("text"):
            if hasattr(part, "text"):
                return part.text
        return None

    def get_image_content(self) -> List[str]:
        """
        Extract all image content (URLs or base64 data) from the message.

        :return: the ``image_url.url`` of every image part, in order; empty
            for string or missing content
        """
        return [
            part.image_url.url
            for part in self._content_parts("image_url")
            if hasattr(part, "image_url") and hasattr(part.image_url, "url")
        ]

    def get_audio_content(self) -> List[str]:
        """
        Extract all audio content from the message as data-URL strings.

        :return: one ``data:<format>;base64,<payload>`` string per audio
            part, in order; empty for string or missing content
        """
        audios = []
        for part in self._content_parts("input_audio"):
            detail = getattr(part, "input_audio", None)
            if detail is None:
                continue
            if hasattr(detail, "data"):
                audios.append(detail.data)
            elif hasattr(detail, "base64_data"):
                # Fallback for detail objects lacking a ``data`` property:
                # build the same data-URL shape by hand.
                format_type = getattr(detail, "format", "mp3")
                audios.append(
                    f"data:{format_type};base64,{detail.base64_data}",
                )
        return audios

    def has_multimodal_content(self) -> bool:
        """
        Check whether the message contains image or audio content.

        :return: True if at least one image or audio part is present
        """
        return bool(
            self.get_image_content() or self.get_audio_content(),
        )

    def get_content_summary(self) -> Dict[str, int]:
        """
        Count the content parts of each type in the message.

        A plain-string ``content`` counts as one text part.

        :return: ``{"text_count": ..., "image_count": ..., "audio_count": ...}``
        """
        if isinstance(self.content, str):
            text_count = 1
        else:
            # Count every text part, consistent with image/audio counts
            # (previously capped at 1).
            text_count = sum(
                1
                for part in self._content_parts("text")
                if hasattr(part, "text")
            )
        return {
            "text_count": text_count,
            "image_count": len(self.get_image_content()),
            "audio_count": len(self.get_audio_content()),
        }


class UserMessage(OpenAIMessage):
    """
    Chat message written by the end user.
    """

    # Author role, preset to ``Role.USER``.
    role: str = Field(default=Role.USER)


class AssistantMessage(OpenAIMessage):
    """
    Chat message produced by the assistant (the model).
    """

    # Author role, preset to ``Role.ASSISTANT``.
    role: str = Field(default=Role.ASSISTANT)


class SystemMessage(OpenAIMessage):
    """
    System-level instruction message.
    """

    # Author role, preset to ``Role.SYSTEM``.
    role: str = Field(default=Role.SYSTEM)


class ToolMessage(OpenAIMessage):
    """
    Message carrying the result of a tool call back to the model.
    """

    # Author role, preset to ``Role.TOOL``.
    role: str = Field(default=Role.TOOL)

    # ID of the tool call this message responds to.
    tool_call_id: str = Field(...)


class ResponseFormat(BaseModel):
    """
    Requested output format for a completion: free text, JSON mode, or JSON
    constrained by a supplied schema.
    """

    class JsonSchema(BaseModel):
        name: str
        """The name of the response format."""

        description: Union[str, None] = None
        """A description of what the response format is for, used by the
        model to determine how to respond in the format.
        """

        schema_param: Optional[dict] = Field(None, alias="schema")
        """The schema for the response format, described as a JSON Schema
        object. Populated from the ``schema`` key (field alias)."""

        strict: Union[bool, None] = False
        """Whether to enable strict schema adherence when generating the
        output.

        If set to true, the model will follow the exact schema defined in the
        `schema` field. Only a subset of JSON Schema is supported when `strict`
        is `true`.
        """

    type: Literal["text", "json_object", "json_schema"] = "text"
    """The type of response format being defined.

    - `text`: The default response format, which can be either text or any
      value needed.
    - `json_object`: Enables JSON mode, which guarantees the message the model
      generates is valid JSON.
    - `json_schema`: Enables Structured Outputs which guarantees the model will
      match your supplied JSON schema.
    """

    json_schema: Optional[JsonSchema] = None
    """The JSON schema for the response format; required when ``type`` is
    ``json_schema`` and not allowed otherwise."""

    @model_validator(mode="before")
    @classmethod
    def validate_schema(cls, values: dict) -> dict:
        """
        Ensure ``json_schema`` is given exactly when ``type`` is
        ``json_schema``.

        Runs on the raw input before field validation.

        :param values: raw input mapping passed to the model
        :return: ``values`` unchanged when consistent
        :raises ValueError: if the input is not a dict, or ``json_schema`` is
            missing for / forbidden with the requested ``type``
        """
        # Declared as a classmethod (pydantic's documented form for
        # mode="before"). A plain ``(self, values)`` function is treated by
        # pydantic as ``(data, info)``, so ``values`` would not be the input.
        if not isinstance(values, dict):
            raise ValueError(f"Json schema not valid with type {type(values)}")
        # An omitted ``type`` means the field default, "text".
        format_type = values.get("type", "text")
        json_schema = values.get("json_schema")

        if format_type in ["text", "json_object"] and json_schema is not None:
            raise ValueError(
                f"Json schema is not allowed for type {format_type}",
            )

        if format_type == "json_schema" and json_schema is None:
            raise ValueError(
                f"Json schema is required for type {format_type}",
            )
        return values


# Names the function selected via ``ToolChoice.function``.
class ToolChoiceInputFunction(BaseModel):
    # Name of the function the model should call.
    name: str = Field(...)


# Explicit tool selection, usable as ``Parameters.tool_choice``.
class ToolChoice(BaseModel):
    # Tool kind; only ``function`` is currently supported.
    type: str = Field(...)

    # The function selected for the model to call.
    function: ToolChoiceInputFunction = Field(...)


class Parameters(BaseModel):
    """
    Sampling and request options shared by LLM chat-completion calls.
    """

    # Nucleus sampling in (0, 1.0]: only tokens within the top ``top_p``
    # probability mass are considered (0.1 -> the top 10% mass). Tune this
    # or ``temperature``, but generally not both.
    top_p: Optional[float] = None

    # Sampling temperature between 0 and 2. Higher values (e.g. 0.8) make
    # output more random, lower ones (e.g. 0.2) more focused and
    # deterministic. Tune this or ``top_p``, but generally not both.
    temperature: Optional[float] = None

    # Positive values penalize tokens by how often they already appeared,
    # making verbatim repetition less likely.
    frequency_penalty: Optional[float] = None

    # Between -2.0 and 2.0. Positive values penalize tokens that already
    # appeared at all, nudging the model toward new topics.
    presence_penalty: Optional[float] = None

    # Maximum number of tokens to generate. Prompt plus completion is still
    # bounded by the model's context length.
    max_tokens: Optional[int] = None

    # Up to 4 sequences at which generation stops.
    stop: Optional[Union[Optional[str], List[str]]] = None

    # Send partial message deltas as they are produced (on by default).
    stream: bool = True

    # Streaming options; only set this when ``stream`` is true.
    stream_options: Optional[ChatCompletionStreamOptionsParam] = None

    # Tools the model may call (currently functions only), given either as
    # ``Tool`` models or plain dicts.
    tools: Optional[List[Union[Tool, Dict]]] = None

    # Which tool, if any, the model calls: a string mode or a ``ToolChoice``.
    tool_choice: Optional[Union[str, ToolChoice]] = None

    # Allow several tool calls in a single turn (off by default).
    parallel_tool_calls: bool = False

    # Maps token IDs to a bias in [-100, 100] added to the logits before
    # sampling: about +-1 nudges likelihood, +-100 effectively forces or
    # bans the token.
    logit_bias: Optional[Dict[str, int]] = None

    # Number (0-20) of most likely tokens to return per position, each with
    # its log probability. Requires ``logprobs`` to be true.
    top_logprobs: Optional[int] = None

    # Return the log probability of each output token in the message content.
    logprobs: Optional[bool] = None

    # Choices to generate per input (validated to 1..5). Cost scales with
    # tokens across all choices, so keep 1 to minimize cost.
    n: Optional[int] = Field(default=1, ge=1, le=5)

    # Best-effort deterministic sampling: same seed and parameters should
    # yield the same result.
    seed: Optional[int] = None

    # Output format. ``{"type": "json_object"}`` enables JSON mode, which
    # guarantees valid JSON output. Defaults to plain text.
    response_format: Optional[Union[ResponseFormat, str]] = ResponseFormat(
        type="text",
    )


def create_chat_completion(
    message: OpenAIMessage,
    model_name: str,
    id: str = "",
    finish_reason: Optional[str] = None,
) -> ChatCompletion:
    """
    Wrap one message in a non-streaming ``chat.completion`` response.

    :param message: message placed in the single choice (index 0)
    :param model_name: value reported in the ``model`` field
    :param id: completion ID, used verbatim (no ID is generated here; the
        default is the empty string)
    :param finish_reason: finish reason reported for the choice
    :return: ``ChatCompletion`` stamped with the current Unix time
    """
    # NOTE(review): ChatCompletion is constructed from dicts, so the dumped
    # message is checked against OpenAI's message model; confirm callers
    # only pass assistant-shaped messages.
    only_choice = dict(
        finish_reason=finish_reason,
        index=0,
        message=message.model_dump(),
        logprobs=None,
    )
    created_at = int(time.time())
    return ChatCompletion(
        id=id,
        choices=[only_choice],
        created=created_at,
        model=model_name,
        object="chat.completion",
        service_tier=None,
        system_fingerprint=None,
        usage=None,
    )


def create_chat_completion_chunk(
    message: OpenAIMessage,
    model_name: str,
    id: str = "",
    finish_reason: Optional[str] = None,
) -> ChatCompletionChunk:
    """
    Wrap one message as the delta of a streaming ``chat.completion.chunk``.

    :param message: message dumped into the single choice's ``delta``
    :param model_name: value reported in the ``model`` field
    :param id: chunk ID, used verbatim (no ID is generated here; the default
        is the empty string)
    :param finish_reason: finish reason for the choice, ``None`` mid-stream
    :return: ``ChatCompletionChunk`` stamped with the current Unix time
    """
    only_choice = dict(
        finish_reason=finish_reason,
        index=0,
        logprobs=None,
        delta=message.model_dump(),
    )
    created_at = int(time.time())
    return ChatCompletionChunk(
        id=id,
        choices=[only_choice],
        created=created_at,
        model=model_name,
        object="chat.completion.chunk",
        service_tier=None,
        system_fingerprint=None,
        usage=None,
    )


def is_json_string(s: Union[str, Dict, BaseModel, None]) -> bool:
    """
    Tell whether ``s`` is JSON text whose top-level value is an object or an
    array.

    Non-text inputs (``None``, dicts, models, ...) are not JSON text and give
    ``False``, as do malformed strings and JSON scalars such as ``"42"``.

    :param s: candidate value; ``str``, ``bytes`` and ``bytearray`` are parsed
    :return: ``True`` only if ``json.loads(s)`` yields a ``dict`` or ``list``
    """
    # json.loads accepts only str/bytes/bytearray; anything else cannot be
    # JSON text, so answer without relying on an exception.
    if not isinstance(s, (str, bytes, bytearray)):
        return False
    try:
        parsed = json.loads(s)
    except (ValueError, RecursionError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError (bytes);
        # RecursionError covers pathologically deep nesting.
        return False
    return isinstance(parsed, (dict, list))
